{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# [针对专业人员的 TensorFlow 2.0 入门](https://www.tensorflow.org/tutorials/quickstart/advanced?hl=zh-cn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.5.0\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "print(tf.__version__)\n",
    "\n",
    "from tensorflow.keras.layers import Dense, Conv2D, Flatten\n",
    "from tensorflow.keras import Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(60000, 28, 28) (10000, 28, 28)\n"
     ]
    }
   ],
   "source": [
    "mnist = tf.keras.datasets.mnist\n",
    "(X_train, y_train), (X_test, y_test) = mnist.load_data()\n",
    "\n",
    "# Divide every value by 255.0 (presumably rescaling 8-bit pixel intensities\n",
    "# into the [0, 1] range -- confirm against the dataset documentation).\n",
    "X_train = X_train / 255.0\n",
    "X_test = X_test / 255.0\n",
    "\n",
    "print(X_train.shape, X_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(60000, 28, 28, 1) (10000, 28, 28, 1)\n"
     ]
    }
   ],
   "source": [
    "# Give every image a trailing channel axis of size 1 and cast to float32.\n",
    "#\n",
    "# The target shape is built from the first three dimensions plus (1,) rather\n",
    "# than by indexing with tf.newaxis, so the cell is idempotent: running it a\n",
    "# second time leaves the shape unchanged instead of stacking yet another\n",
    "# axis onto the arrays.\n",
    "X_train = X_train.reshape(X_train.shape[:3] + (1,)).astype(\"float32\")\n",
    "X_test = X_test.reshape(X_test.shape[:3] + (1,)).astype(\"float32\")\n",
    "print(X_train.shape, X_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training pipeline: slice the arrays into (image, label) pairs, shuffle\n",
    "# with a 10000-element buffer, then group the pairs into batches of 32.\n",
    "train_ds = (\n",
    "    tf.data.Dataset.from_tensor_slices((X_train, y_train))\n",
    "    .shuffle(10000)\n",
    "    .batch(32)\n",
    ")\n",
    "\n",
    "# Evaluation pipeline: same batch size, no shuffling.\n",
    "test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<__main__.MyModel object at 0x7fc182007280>\n"
     ]
    }
   ],
   "source": [
    "class MyModel(Model):\n",
    "    \"\"\"Small convolutional classifier built by subclassing the Keras Model.\n",
    "\n",
    "    Layers, in the order `call` applies them: a convolution with 32 filters\n",
    "    and kernel size 3 (ReLU), a flatten, a 128-unit dense layer (ReLU), and\n",
    "    a final 10-unit dense layer with no activation.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = Conv2D(filters=32, kernel_size=3, activation=\"relu\")\n",
    "        self.flatten = Flatten()\n",
    "        self.d1 = Dense(units=128, activation=\"relu\")\n",
    "        self.d2 = Dense(units=10)\n",
    "\n",
    "    def call(self, x):\n",
    "        \"\"\"Return the output of the final dense layer for input batch `x`.\"\"\"\n",
    "        feature_maps = self.conv1(x)\n",
    "        flat = self.flatten(feature_maps)\n",
    "        hidden = self.d1(flat)\n",
    "        return self.d2(hidden)\n",
    "\n",
    "\n",
    "# Build the network once and bind it to the name `model`.\n",
    "model = MyModel()\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from_logits=True tells the loss that the predictions it receives are\n",
    "# unnormalized scores rather than probabilities.\n",
    "loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(\n",
    "    from_logits=True,\n",
    ")\n",
    "\n",
    "# Adam, constructed with its default arguments.\n",
    "optimizer = tf.keras.optimizers.Adam()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Running mean of the loss values fed to each metric, one per data split.\n",
    "train_loss = tf.keras.metrics.Mean(name=\"train_loss\")\n",
    "test_loss = tf.keras.metrics.Mean(name=\"test_loss\")\n",
    "\n",
    "# Running classification accuracy, one per data split.\n",
    "train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
    "    name=\"train_accuracy\"\n",
    ")\n",
    "test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
    "    name=\"test_accuracy\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def train_step(images, labels):\n",
    "    \"\"\"Run one optimization step on a batch and record the training metrics.\n",
    "\n",
    "    Args:\n",
    "        images: Batch of inputs passed to `model`.\n",
    "        labels: Targets for `images`, passed to `loss_fn` and to the\n",
    "            accuracy metric.\n",
    "    \"\"\"\n",
    "    with tf.GradientTape() as tape:\n",
    "        # The training flag only matters for layers that behave differently\n",
    "        # during training and inference (e.g. Dropout).\n",
    "        logits = model(images, training=True)\n",
    "        batch_loss = loss_fn(labels, logits)\n",
    "\n",
    "    # Differentiate the batch loss w.r.t. every trainable weight, then let the\n",
    "    # optimizer apply the resulting update.\n",
    "    variables = model.trainable_variables\n",
    "    grads = tape.gradient(batch_loss, variables)\n",
    "    optimizer.apply_gradients(zip(grads, variables))\n",
    "\n",
    "    # Fold this batch into the running training metrics.\n",
    "    train_loss(batch_loss)\n",
    "    train_accuracy(labels, logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def test_step(images, labels):\n",
    "    \"\"\"Evaluate one batch and record the test metrics; no weights are updated.\n",
    "\n",
    "    Args:\n",
    "        images: Batch of inputs passed to `model`.\n",
    "        labels: Targets for `images`, passed to `loss_fn` and to the\n",
    "            accuracy metric.\n",
    "    \"\"\"\n",
    "    # The training flag only matters for layers that behave differently\n",
    "    # during training and inference (e.g. Dropout).\n",
    "    logits = model(images, training=False)\n",
    "    batch_loss = loss_fn(labels, logits)\n",
    "\n",
    "    # Fold this batch into the running test metrics.\n",
    "    test_loss(batch_loss)\n",
    "    test_accuracy(labels, logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, Loss: 0.14193545281887054, Accuracy: 95.69%, Test Loss: 0.060119640082120895, Test Accuracy: 98.01%\n",
      "Epoch 2, Loss: 0.04455890133976936, Accuracy: 98.64%, Test Loss: 0.059963829815387726, Test Accuracy: 98.08%\n",
      "Epoch 3, Loss: 0.02440371923148632, Accuracy: 99.22%, Test Loss: 0.05960312858223915, Test Accuracy: 98.20%\n",
      "Epoch 4, Loss: 0.01419108361005783, Accuracy: 99.53%, Test Loss: 0.06092323362827301, Test Accuracy: 98.41%\n",
      "Epoch 5, Loss: 0.009066925384104252, Accuracy: 99.69%, Test Loss: 0.06790931522846222, Test Accuracy: 98.20%\n"
     ]
    }
   ],
   "source": [
    "EPOCHS = 5\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "    # Clear every accumulator so the numbers printed below describe this\n",
    "    # epoch only.\n",
    "    # NOTE(review): newer Keras releases spell this method reset_state();\n",
    "    # confirm against the installed version before upgrading TensorFlow.\n",
    "    for metric in (train_loss, train_accuracy, test_loss, test_accuracy):\n",
    "        metric.reset_states()\n",
    "\n",
    "    # One full pass over the training batches, updating the weights.\n",
    "    for images, labels in train_ds:\n",
    "        train_step(images, labels)\n",
    "\n",
    "    # One full pass over the test batches, updating metrics only.\n",
    "    for test_images, test_labels in test_ds:\n",
    "        test_step(test_images, test_labels)\n",
    "\n",
    "    print(\n",
    "        f'Epoch {epoch + 1}, '\n",
    "        f'Loss: {train_loss.result()}, '\n",
    "        f'Accuracy: {train_accuracy.result() * 100:.2f}%, '\n",
    "        f'Test Loss: {test_loss.result()}, '\n",
    "        f'Test Accuracy: {test_accuracy.result() * 100:.2f}%'\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
